{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "lstm_in_colab.ipynb",
      "version": "0.3.2",
      "views": {},
      "default_view": {},
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python2",
      "display_name": "Python 2"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "rx1jX7FVZzb2",
        "colab_type": "text"
      },
      "source": [
        "# 介绍\n",
        "这里以我在 GitHub 上开源的 LSTM 文本分类项目为例：<https://github.com/Jinkeycode/keras_lstm_chinese_document_classification>。\n",
        "把 `master/data` 目录下的三个文件存放到 Google Drive 上。该示例演示的是对健康、科技、设计三个类别的标题进行分类。\n",
        "\n",
        "# 安装依赖\n",
        "TensorFlow、NumPy、scikit-learn 在 Colab 中已经预装，无需另行安装。"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {
        "id": "P2FTAmQraB5Y",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 3
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 171
        },
        "outputId": "ebeae7f3-be1e-4506-f247-2cea182aabae",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1515226733793,
          "user_tz": -480,
          "elapsed": 6000,
          "user": {
            "displayName": "黄俊杰",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
            "userId": "110224293212919869066"
          }
        }
      },
      "source": [
        "!pip install keras\n",
        "!pip install jieba\n",
        "!pip install h5py\n",
        "\n",
        "import h5py\n",
        "import jieba as jb\n",
        "import numpy as np\n",
        "import keras as krs\n",
        "import tensorflow as tf\n",
        "from sklearn.preprocessing import LabelEncoder"
      ],
      "cell_type": "code",
      "execution_count": 37,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: keras in /usr/local/lib/python2.7/dist-packages\r\n",
            "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python2.7/dist-packages (from keras)\r\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python2.7/dist-packages (from keras)\r\n",
            "Requirement already satisfied: numpy>=1.9.1 in /usr/local/lib/python2.7/dist-packages (from keras)\r\n",
            "Requirement already satisfied: scipy>=0.14 in /usr/local/lib/python2.7/dist-packages (from keras)\n",
            "Requirement already satisfied: jieba in /usr/local/lib/python2.7/dist-packages\n",
            "Requirement already satisfied: h5py in /usr/local/lib/python2.7/dist-packages\n",
            "Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python2.7/dist-packages (from h5py)\n",
            "Requirement already satisfied: six in /usr/local/lib/python2.7/dist-packages (from h5py)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "vKlrxXszag-X",
        "colab_type": "text"
      },
      "source": [
        "# 加载数据"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {
        "id": "Ce9chUz6ak_E",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 154
        },
        "outputId": "49cbfe64-2b38-460d-de2c-8a5dfbc7e43d",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1515226446807,
          "user_tz": -480,
          "elapsed": 3967,
          "user": {
            "displayName": "黄俊杰",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
            "userId": "110224293212919869066"
          }
        }
      },
      "source": [
        "# 安装 PyDrive 操作库，该操作每个 notebook 只需要执行一次\n",
        "!pip install -U -q PyDrive\n",
        "from pydrive.auth import GoogleAuth\n",
        "from pydrive.drive import GoogleDrive\n",
        "from google.colab import auth\n",
        "from oauth2client.client import GoogleCredentials\n",
        "\n",
        "def login_google_drive():\n",
        "  \"\"\"Authenticate the Colab user and return a PyDrive ``GoogleDrive`` client.\"\"\"\n",
        "  # Interactive authorization; per the original note, the prompt only appears the first time.\n",
        "  auth.authenticate_user()\n",
        "  google_auth = GoogleAuth()\n",
        "  google_auth.credentials = GoogleCredentials.get_application_default()\n",
        "  return GoogleDrive(google_auth)\n",
        "\n",
        "def list_file(drive):\n",
        "  \"\"\"Print title, id and mimeType for each non-trashed entry whose parent is 'root'.\"\"\"\n",
        "  query = {'q': \"'root' in parents and trashed=false\"}\n",
        "  for entry in drive.ListFile(query).GetList():\n",
        "    print('title: %s, id: %s, mimeType: %s' % (entry['title'], entry['id'], entry[\"mimeType\"]))\n",
        "    \n",
        "\n",
        "drive = login_google_drive()\n",
        "list_file(drive)"
      ],
      "cell_type": "code",
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "title: tech.txt, id: 14sDl4520Tpo1MLPydjNBoq-QjqOKk9t6, mimeType: text/plain\n",
            "title: health.txt, id: 117GkBtuuBP3wVjES0X0L4wVF5rp5Cewi, mimeType: text/plain\n",
            "title: design.txt, id: 1J4lndcsjUb8_VfqPcfsDeOoB21bOLea3, mimeType: text/plain\n",
            "title: iris, id: 1M3o-kSs59l0PqLNPmd3XPKyMWxRJG8-vvRtUjHVNpAY, mimeType: application/vnd.google-apps.spreadsheet\n",
            "title: iris.csv, id: 1SM_fLhCcYRsGxgHproAQ1RLrJAZ_Qcem, mimeType: text/csv\n",
            "title: Colab Notebooks, id: 1U9363AsQAlJTP2nSeoVae9zDKSsKj5Jj, mimeType: application/vnd.google-apps.folder\n",
            "title: dped.gz, id: 0BwOLOmqkYj-jeUJwQjRNUFkzOTA, mimeType: application/gzip\n",
            "title: models+code.gz, id: 0BwOLOmqkYj-jcnZpaDR0dU9XMm8, mimeType: application/x-gzip\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "7F5_wKqucs3k",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "62b91d39-1948-48e5-fa94-3805415b6cca",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1515226451119,
          "user_tz": -480,
          "elapsed": 3388,
          "user": {
            "displayName": "黄俊杰",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
            "userId": "110224293212919869066"
          }
        }
      },
      "source": [
        "def cache_data():\n",
        "  \"\"\"Download the three category files from Google Drive into the working directory.\n",
        "\n",
        "  Uses the module-level ``drive`` client.\n",
        "  \"\"\"\n",
        "  # (Drive file id, local file name). Replace each id with the matching id\n",
        "  # printed by list_file().\n",
        "  sources = [\n",
        "      (\"117GkBtuuBP3wVjES0X0L4wVF5rp5Cewi\", 'health.txt'),\n",
        "      (\"14sDl4520Tpo1MLPydjNBoq-QjqOKk9t6\", 'tech.txt'),\n",
        "      (\"1J4lndcsjUb8_VfqPcfsDeOoB21bOLea3\", 'design.txt'),\n",
        "  ]\n",
        "  handles = [(drive.CreateFile({'id': file_id}), local_name) for file_id, local_name in sources]\n",
        "\n",
        "  # The download is only a local cache; per the original author's note it does\n",
        "  # not create an additional file in your Google Drive folder.\n",
        "  for handle, local_name in handles:\n",
        "    handle.GetContentFile(local_name, \"text/plain\")\n",
        "\n",
        "  print(\"缓存成功\")\n",
        "  \n",
        "cache_data()"
      ],
      "cell_type": "code",
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "缓存成功\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "ts21ooqia3-G",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 86
        },
        "outputId": "bf72b2f8-4deb-4b53-be7d-4c3de21c8d2c",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1515226454485,
          "user_tz": -480,
          "elapsed": 1253,
          "user": {
            "displayName": "黄俊杰",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
            "userId": "110224293212919869066"
          }
        }
      },
      "source": [
        "def load_data():\n",
        "    \"\"\"Read the three category files and return all stripped lines as one list.\n",
        "\n",
        "    Files are read in the order health.txt, tech.txt, design.txt from the\n",
        "    current working directory; a progress message is printed before each one\n",
        "    and the total number of lines is printed at the end.\n",
        "    \"\"\"\n",
        "    # (progress message, file name), in loading order.\n",
        "    sources = [\n",
        "        (\"正在加载健康类别的数据...\", \"health.txt\"),\n",
        "        (\"正在加载科技类别的数据...\", \"tech.txt\"),\n",
        "        (\"正在加载设计类别的数据...\", \"design.txt\"),\n",
        "    ]\n",
        "\n",
        "    collected = []\n",
        "    for message, path in sources:\n",
        "        print(message)\n",
        "        with open(path, \"r\") as handle:\n",
        "            collected.extend(row.strip() for row in handle)\n",
        "\n",
        "    print(\"一共加载了 %s 个标题\" % len(collected))\n",
        "\n",
        "    return collected\n",
        "  \n",
        "titles = load_data()"
      ],
      "cell_type": "code",
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "正在加载健康类别的数据...\n",
            "正在加载科技类别的数据...\n",
            "正在加载设计类别的数据...\n",
            "一共加载了 31318 个标题\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "lZiWADs6egOG",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "8c681728-6c95-4649-9268-2bdd865d8f80",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1515226459694,
          "user_tz": -480,
          "elapsed": 1128,
          "user": {
            "displayName": "黄俊杰",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
            "userId": "110224293212919869066"
          }
        }
      },
      "source": [
        "def load_label(class_sizes=(12000, 12000, 7318)):\n",
        "    \"\"\"Build one-hot labels for titles that are stored grouped by class.\n",
        "\n",
        "    class_sizes: number of titles per class, in the order the files are\n",
        "        loaded (health, tech, design). The default reproduces the counts that\n",
        "        were previously hard-coded; pass other sizes when the data files differ.\n",
        "\n",
        "    Returns the array produced by ``to_categorical`` (one row per title).\n",
        "    \"\"\"\n",
        "    # Label i is repeated class_sizes[i] times: 0...0 1...1 2...2\n",
        "    target = np.repeat(np.arange(len(class_sizes)), class_sizes)\n",
        "    print(\"一共加载了 %s 个标签\" % target.shape[0])\n",
        "\n",
        "    encoded_target = LabelEncoder().fit_transform(target)\n",
        "    dummy_target = krs.utils.np_utils.to_categorical(encoded_target)\n",
        "\n",
        "    return dummy_target\n",
        "  \n",
        "target = load_label()"
      ],
      "cell_type": "code",
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "一共加载了 31318 个标签\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "VDVL5T0Vl_AB",
        "colab_type": "text"
      },
      "source": [
        "# 文本预处理"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {
        "id": "i8nIzHJje5ON",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "d19fe3b3-7220-47cd-841b-3ab411e3f4e4",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1515226467760,
          "user_tz": -480,
          "elapsed": 3673,
          "user": {
            "displayName": "黄俊杰",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
            "userId": "110224293212919869066"
          }
        }
      },
      "source": [
        "max_sequence_length = 30\n",
        "embedding_size = 50\n",
        "\n",
        "# Segment every title with jieba (full mode) and join the tokens with a space,\n",
        "# the same separator the prediction cell uses for its input sentence.\n",
        "# The result gets its own name so that re-running this cell does not segment\n",
        "# already-segmented text.\n",
        "segmented_titles = [\" \".join(jb.cut(t, cut_all=True)) for t in titles]\n",
        "\n",
        "# Map each token to an integer id with VocabularyProcessor.\n",
        "vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(max_sequence_length, min_frequency=1)\n",
        "text_processed = np.array(list(vocab_processor.fit_transform(segmented_titles)))\n",
        "\n",
        "# Token -> id mapping and its items sorted by id.\n",
        "# NOTE: the name `dict` shadows the builtin; it is kept because later cells may read it.\n",
        "dict = vocab_processor.vocabulary_._mapping\n",
        "sorted_vocab = sorted(dict.items(), key = lambda x : x[1])"
      ],
      "cell_type": "code",
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "19321\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "duL50JxlmCgT",
        "colab_type": "text"
      },
      "source": [
        "# 构建神经网络\n",
        "这里使用 Embedding 和 LSTM 作为前两层，再接一个全连接层，通过 softmax 激活输出三个类别的概率。"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {
        "id": "9GKNjklvfDa5",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 78
            },
            {
              "item_id": 142
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 395
        },
        "outputId": "e069ba9a-42f3-4fa9-a834-6c8916a0f3be",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1515227365778,
          "user_tz": -480,
          "elapsed": 77317,
          "user": {
            "displayName": "黄俊杰",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
            "userId": "110224293212919869066"
          }
        }
      },
      "source": [
        "# 配置网络结构\n",
        "def build_netword(num_vocabs):\n",
        "    \"\"\"Build and compile the classifier: Embedding -> LSTM -> Dense(3) -> softmax.\n",
        "\n",
        "    num_vocabs: passed to ``Embedding`` as its first argument.\n",
        "    Reads the module-level ``embedding_size`` and ``max_sequence_length``.\n",
        "    Returns the compiled ``Sequential`` model.\n",
        "    \"\"\"\n",
        "    layers = [\n",
        "        krs.layers.Embedding(num_vocabs, embedding_size, input_length=max_sequence_length),\n",
        "        krs.layers.LSTM(32, dropout=0.2, recurrent_dropout=0.2),\n",
        "        krs.layers.Dense(3),\n",
        "        krs.layers.Activation(\"softmax\"),\n",
        "    ]\n",
        "    network = krs.Sequential(layers)\n",
        "    network.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
        "\n",
        "    return network\n",
        "  \n",
        "# Vocabulary size is read from the processor itself rather than from a\n",
        "# variable named `dict`, which shadows the builtin.\n",
        "num_vocabs = len(vocab_processor.vocabulary_)\n",
        "model = build_netword(num_vocabs=num_vocabs)\n",
        "\n",
        "import time\n",
        "start = time.time()\n",
        "# Train the model\n",
        "model.fit(text_processed, target, batch_size=512, epochs=10)\n",
        "finish = time.time()\n",
        "print(\"训练耗时：%f 秒\" %(finish-start))\n",
        "# Save the model\n",
        "# model.save(\"health_and_tech_design.h5\")  (original author's note: Colab still reported h5py as missing even after installing it)\n",
        "\n",
        "# Load the pretrained weights\n",
        "# model.load_weights(\"health_and_tech_design.h5\")"
      ],
      "cell_type": "code",
      "execution_count": 39,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 1/10\n",
            "31318/31318 [==============================] - 8s 259us/step - loss: 1.0316 - acc: 0.4667\n",
            "Epoch 2/10\n",
            "31318/31318 [==============================] - 7s 235us/step - loss: 0.5546 - acc: 0.7649\n",
            "Epoch 3/10\n",
            "31318/31318 [==============================] - 7s 236us/step - loss: 0.2340 - acc: 0.9349\n",
            "Epoch 4/10\n",
            "31318/31318 [==============================] - 7s 237us/step - loss: 0.1372 - acc: 0.9644\n",
            "Epoch 5/10\n",
            "31318/31318 [==============================] - 7s 236us/step - loss: 0.1049 - acc: 0.9740\n",
            "Epoch 6/10\n",
            " 5632/31318 [====>.........................] - ETA: 6s - loss: 0.0809 - acc: 0.9837"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "31318/31318 [==============================] - 8s 240us/step - loss: 0.0842 - acc: 0.9806\n",
            "Epoch 7/10\n",
            "31318/31318 [==============================] - 7s 236us/step - loss: 0.0654 - acc: 0.9851\n",
            "Epoch 8/10\n",
            "31318/31318 [==============================] - 7s 236us/step - loss: 0.0603 - acc: 0.9866\n",
            "Epoch 9/10\n",
            "31318/31318 [==============================] - 7s 236us/step - loss: 0.0502 - acc: 0.9894\n",
            "Epoch 10/10\n",
            "31318/31318 [==============================] - 7s 235us/step - loss: 0.0442 - acc: 0.9906\n",
            "训练耗时：76.154183 秒\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "3xxJNVl2mS4M",
        "colab_type": "text"
      },
      "source": [
        "# 预测样本\n",
        "`sen` 可以换成你自己的句子。预测结果为 [健康类文章概率, 科技类文章概率, 设计类文章概率]，概率最高的类别即为预测类别；但当最大概率不超过 0.8 时，判定为无法可信分类的文章。"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {
        "id": "1wwVrx71gIvY",
        "colab_type": "code",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ],
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "1af20d9b-a8c1-4d51-9632-82f4dcca4e25",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1515227387283,
          "user_tz": -480,
          "elapsed": 1160,
          "user": {
            "displayName": "黄俊杰",
            "photoUrl": "https://lh3.googleusercontent.com/a/default-user=s128",
            "userId": "110224293212919869066"
          }
        }
      },
      "source": [
        "sen = \"做好商业设计需要学习的小技巧\"\n",
        "# Segment the sentence and map it to ids with the processor fitted on the training titles.\n",
        "sen_processed = \" \".join(jb.cut(sen, cut_all=True))\n",
        "sen_processed = np.array(list(vocab_processor.transform([sen_processed])))\n",
        "result = model.predict(sen_processed)\n",
        "\n",
        "# Index of the highest probability; class order is [health, tech, design].\n",
        "probabilities = result[0]\n",
        "catalogue = int(np.argmax(probabilities))\n",
        "threshold=0.8\n",
        "messages = [\n",
        "    \"这是一篇关于健康的文章\",\n",
        "    \"这是一篇关于科技的文章\",\n",
        "    \"这是一篇关于设计的文章\",\n",
        "]\n",
        "if probabilities[catalogue] > threshold:\n",
        "    print(messages[catalogue])\n",
        "else:\n",
        "    # Fix: this branch used to hang off the inner if/elif chain, so it could\n",
        "    # never run and low-confidence predictions printed nothing.\n",
        "    print(\"这篇文章没有可信分类\")"
      ],
      "cell_type": "code",
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "这是一篇关于设计的文章\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}